package com.diven.spark.ml.learn.feature

import org.apache.spark.sql.{DataFrame}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DataTypes
import com.diven.spark.ml.learn.core.BaseTest
import com.diven.spark.ml.learn.core.BaseSpark
import org.apache.spark.ml.feature.HashingTF
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.IDF
import org.apache.spark.ml.feature.Tokenizer
import org.apache.spark.ml.linalg.Vector

/**
 * Companion factory for [[IDFTest]]. Returns the test wrapped in its
 * [[BaseSpark]] interface so the harness can run it uniformly.
 */
object IDFTest extends BaseTest {

    /** Builds a fresh [[IDFTest]] instance. */
    def apply(): BaseSpark = {
        val test = new IDFTest()
        test
    }

}

/**
 * TF-IDF feature-extraction example: tokenizes short sentences, hashes the
 * tokens into a fixed-size term-frequency vector with [[HashingTF]], then
 * rescales the frequencies with an [[IDF]] model fitted on the same data.
 */
class IDFTest extends BaseSpark {

    /**
     * Builds the tiny in-memory corpus used by this example.
     *
     * @param sparkSession session used to create the frame; defaults to the
     *                     one provided by [[BaseSpark]]
     * @return a DataFrame with columns `id` (Int) and `sentence` (String)
     */
    override def getDataFrame(sparkSession: SparkSession = this.getSparkSession()): DataFrame = {
         sparkSession.createDataFrame(Seq(
              (0, "A A B B"),
              (1, "A A C C"),
              (2, "A B B C D")
         ))
         .toDF("id", "sentence")
     }

    /**
     * Runs the Tokenizer -> HashingTF -> IDF pipeline over `dataFrame`
     * and prints the rescaled result plus its schema.
     *
     * @param dataFrame input with a `sentence` string column
     */
    override def execute(dataFrame: DataFrame): Unit = {
        // Tokenize: split each sentence on whitespace into a `words` array.
        val wordsData = new Tokenizer()
        .setInputCol("sentence")
        .setOutputCol("words")
        .transform(dataFrame)
        // Feature hashing: map each token to a bucket and count occurrences.
        // `val` (was `var`): the frame is never reassigned.
        val hashingTFDF = new HashingTF()
        .setInputCol("words")             // feature column to transform
        .setOutputCol("words_hash")       // name of the transformed feature column
        .setBinary(false)                 // if true, every non-zero count is converted to 1
        .setNumFeatures(262144)           // number of hash buckets (Spark's default: 262144)
        .transform(wordsData)
        // IDF: down-weights terms that appear in many documents.
        val idf = new IDF()
        .setInputCol("words_hash")        // feature column to transform
        .setOutputCol("words_idf")        // name of the transformed feature column
        .setMinDocFreq(1)                 // terms below this document frequency get weight 0
        // Fit the IDF model on the hashed term frequencies.
        val idfModel = idf.fit(hashingTFDF)
        // Apply the model to produce TF-IDF vectors.
        val rescaledData = idfModel.transform(hashingTFDF)
        // Show up to 10 rows, truncating cell text at 1000 characters.
        rescaledData.show(10, 1000)
        rescaledData.printSchema()
    }

}
